# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Frame utils for MegaSaM."""

# pylint: disable=invalid-name
# pylint: disable=g-doc-args
# pylint: disable=broad-exception-raised

import os
import re

import cv2
import numpy as np
from PIL import Image

# Configure OpenCV globally at import time: setNumThreads(0) turns off
# OpenCV's internal threading and setUseOpenCL(False) disables OpenCL.
# NOTE(review): presumably done to avoid contention with data-loader worker
# processes/threads — confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

# Magic number written at the start of .flo files by writeFlow (readFlow checks
# for the same value, 202021.25).
TAG_CHAR = np.array([202021.25], np.float32)


def readFlow(fn):
  """Read a .flo optical-flow file in Middlebury format.

  Layout: a float32 magic number (202021.25), int32 width, int32 height, then
  2 * width * height float32 values holding interleaved (u, v) pairs in
  row-major order. All fields are read as little-endian, so this works on any
  host architecture.

  Args:
    fn: Path to the .flo file.

  Returns:
    A float32 array of shape (height, width, 2) with (u, v) per pixel, or None
    if the file does not start with the expected magic number (including an
    empty file).

  Raises:
    ValueError: If the width/height header is truncated or negative, or if the
      payload holds fewer than 2 * width * height values.
  """
  # Code adapted from:
  # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
  with open(fn, 'rb') as f:
    # Explicit little-endian dtypes instead of native ones so the reader does
    # not silently misparse files on big-endian hosts.
    magic = np.fromfile(f, '<f4', count=1)
    if magic.size != 1 or magic[0] != 202021.25:
      print('Magic number incorrect. Invalid .flo file')
      return None

    dims = np.fromfile(f, '<i4', count=2)
    if dims.size != 2:
      raise ValueError('Truncated .flo header in %s' % fn)
    w, h = int(dims[0]), int(dims[1])
    if w < 0 or h < 0:
      raise ValueError('Negative .flo dimensions %d x %d in %s' % (w, h, fn))

    n_values = 2 * w * h
    data = np.fromfile(f, '<f4', count=n_values)
    if data.size != n_values:
      # np.resize (used previously) would have tiled the partial data to fill
      # the requested shape and returned garbage; fail loudly instead.
      raise ValueError(
          'Truncated .flo payload in %s: expected %d values, got %d'
          % (fn, n_values, data.size)
      )

  # Row-major (h, w, 2) layout matching image indexing; the original
  # Middlebury reader used (w, h, 2).
  return data.reshape(h, w, 2).astype(np.float32, copy=False)


def readPFM(file):
  """Read a PFM (Portable Float Map) file.

  Args:
    file: Path to the .pfm file.

  Returns:
    A float32 array of shape (height, width, 3) for 'PF' (color) files or
    (height, width) for 'Pf' (grayscale) files. The byte order follows the
    header's scale sign. Rows are flipped vertically (np.flipud) to undo PFM's
    bottom-to-top scanline order. The scale magnitude is not applied to the
    data.

  Raises:
    ValueError: If the header is not a valid PFM header, or the pixel data
      does not match the declared dimensions.
  """
  # `with` guarantees the handle is closed; the previous version leaked it.
  with open(file, 'rb') as f:
    header = f.readline().rstrip()
    if header == b'PF':
      color = True
    elif header == b'Pf':
      color = False
    else:
      raise ValueError('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', f.readline())
    if not dim_match:
      raise ValueError('Malformed PFM header.')
    width, height = map(int, dim_match.groups())

    # A negative scale marks little-endian data, a positive one big-endian.
    scale = float(f.readline().rstrip())
    endian = '<' if scale < 0 else '>'

    data = np.fromfile(f, endian + 'f')

  shape = (height, width, 3) if color else (height, width)
  data = np.reshape(data, shape)
  return np.flipud(data)


def writeFlow(filename, uv, v=None):
  """Write optical flow to a Middlebury .flo file.

  If v is None, uv is assumed to contain both u and v channels, stacked in
  depth (shape (height, width, 2)). Otherwise uv is the u channel and v the v
  channel, each of shape (height, width).
  Original code by Deqing Sun, adapted from Daniel Scharstein.

  The file holds the TAG_CHAR magic, int32 width, int32 height, then
  interleaved (u, v) float32 pairs in row-major order. Every field is written
  little-endian regardless of host byte order.

  Args:
    filename: Output path.
    uv: Flow array (see above).
    v: Optional separate v channel.

  Raises:
    ValueError: If the input shapes are inconsistent.
  """
  uv = np.asarray(uv)
  if v is None:
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if uv.ndim != 3 or uv.shape[2] != 2:
      raise ValueError(
          'uv must have shape (height, width, 2), got %s' % (uv.shape,)
      )
    u = uv[:, :, 0]
    v = uv[:, :, 1]
  else:
    u = uv
    v = np.asarray(v)

  if u.ndim != 2 or u.shape != v.shape:
    raise ValueError(
        'u and v must be 2-D with equal shapes, got %s and %s'
        % (u.shape, v.shape)
    )
  height, width = u.shape

  # Interleave (u, v) per pixel before opening the file, so a failing
  # conversion does not leave a half-written file behind.
  payload = np.stack([u, v], axis=-1).astype('<f4')

  with open(filename, 'wb') as f:
    np.asarray(TAG_CHAR, dtype='<f4').tofile(f)
    np.array([width, height], dtype='<i4').tofile(f)
    payload.tofile(f)


def readFlowKITTI(filename):
  """Read a KITTI-format optical-flow PNG.

  Args:
    filename: Path to the 16-bit, 3-channel flow image.

  Returns:
    (flow, valid): flow is a float32 array of shape (H, W, 2) with (u, v);
    valid is a float32 array of shape (H, W) taken from the third channel.

  Raises:
    IOError: If OpenCV cannot read the file.
  """
  flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
  if flow is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise IOError('Could not read KITTI flow file: %s' % filename)
  # Reverse OpenCV's BGR channel order to (u, v, valid).
  flow = flow[:, :, ::-1].astype(np.float32)
  flow, valid = flow[:, :, :2], flow[:, :, 2]
  # Undo the encoding stored_value = 64 * flow + 2**15.
  flow = (flow - 2**15) / 64.0
  return flow, valid


def readDispKITTI(filename):
  """Read a KITTI-format disparity PNG and return it as a flow field.

  Args:
    filename: Path to the disparity image.

  Returns:
    (flow, valid): flow has shape (H, W, 2) with u = -disparity and v = 0;
    valid is a boolean (H, W) mask of pixels with positive disparity.

  Raises:
    IOError: If OpenCV cannot read the file.
  """
  raw = cv2.imread(filename, cv2.IMREAD_ANYDEPTH)
  if raw is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise IOError('Could not read KITTI disparity file: %s' % filename)
  disp = raw / 256.0
  # A zero disparity marks a pixel with no measurement.
  valid = disp > 0.0
  # Express disparity as a purely horizontal (negative-x) displacement.
  flow = np.stack([-disp, np.zeros_like(disp)], -1)
  return flow, valid


def writeFlowKITTI(filename, uv):
  """Write optical flow as a KITTI-format 16-bit PNG.

  Each pixel stores (64 * u + 2**15, 64 * v + 2**15, 1) as uint16, with the
  third channel marking every pixel valid.

  Args:
    filename: Output path.
    uv: Flow array of shape (H, W, 2).

  Raises:
    IOError: If OpenCV fails to write the file.
  """
  # Clip into the uint16 range before casting. A float->uint16 cast of
  # out-of-range values is not well-defined in NumPy and can silently wrap
  # large flows into wrong values.
  uv = np.clip(64.0 * uv + 2**15, 0, 2**16 - 1)
  valid = np.ones([uv.shape[0], uv.shape[1], 1])
  uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
  # Reverse to OpenCV's BGR channel order: (valid, v, u).
  if not cv2.imwrite(filename, uv[..., ::-1]):
    raise IOError('Could not write KITTI flow file: %s' % filename)


def read_gen(file_name, pil=False):
  """Read an image, array, or flow file, dispatching on its extension.

  Args:
    file_name: Path to the file. The extension is matched case-insensitively.
    pil: Unused; kept for backward compatibility.

  Returns:
    - '.png', '.jpeg', '.jpg', '.ppm': a PIL Image from Image.open.
    - '.bin', '.raw': the result of np.load on the file.
    - '.flo': a float32 (H, W, 2) flow array from readFlow.
    - '.pfm': a float32 array from readPFM; 2-D data is returned as-is,
      otherwise the last channel is dropped.
    - any other extension: an empty list.

  Raises:
    ValueError: If a '.flo' file does not carry the expected magic number.
  """
  del pil
  ext = os.path.splitext(file_name)[-1].lower()
  if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
    return Image.open(file_name)
  if ext in ('.bin', '.raw'):
    return np.load(file_name)
  if ext == '.flo':
    flow = readFlow(file_name)
    if flow is None:
      # readFlow reports a bad magic number by returning None. Surface it
      # clearly instead of failing on `None.astype`.
      raise ValueError('Invalid .flo file: %s' % file_name)
    return flow.astype(np.float32)
  if ext == '.pfm':
    flow = readPFM(file_name).astype(np.float32)
    return flow if flow.ndim == 2 else flow[:, :, :-1]
  return []
